import os.path
import sys
import time

import numpy as np
from torch import multiprocessing
from torch.utils.tensorboard.writer import SummaryWriter
from continual_rl.utils import utils
from continual_rl.utils.argparse_manager import ArgparseManager
from continual_rl.utils.metrics import Metrics

def _build_metrics_config(output_dir, policy_name, num_tasks=3, num_cycles=1, num_task_steps=3e6):
    """Build the configuration dict handed to ``Metrics`` for one finished run.

    The run is described as a single model ("Policy") with a single run
    directory. The layout is ``num_tasks`` tasks trained back to back for
    ``num_task_steps`` steps each, with the whole sequence repeated
    ``num_cycles`` times.

    Args:
        output_dir: The experiment's output directory. Its parent directory becomes
            ``exp_dir`` and its basename the single entry of ``runs``.
        policy_name: Display name used for the model.
        num_tasks: Number of tasks in the sequence.
        num_cycles: How many times the task sequence is repeated.
        num_task_steps: Steps spent on each task per visit.

    Returns:
        dict: The configuration. Key names (including the ``gird_size`` spelling)
        are presumably read verbatim by ``Metrics`` -- confirm before renaming any.
    """
    total_steps = num_tasks * num_cycles * num_task_steps
    tasks = {
        f"Task{task_id}": dict(
            i=task_id,
            eval_i=task_id,
            y_range=[0., 1],
            yaxis_dtick=0.2,
            # Task ``task_id`` owns every ``num_tasks``-th slot of ``num_task_steps`` steps.
            train_regions=[[num_task_steps * slot, num_task_steps * (slot + 1)]
                           for slot in range(task_id, num_tasks * num_cycles, num_tasks)],
        )
        for task_id in range(num_tasks)
    }
    return {
        'tag_base': 'eval_reward',
        'cache_dir': '/tmp',
        'legend_size': 30,
        'title_size': 40,
        'axis_size': 20,
        'axis_label_size': 30,
        'exp_dir': os.path.dirname(output_dir),
        'models': {"Policy": {'name': policy_name, 'runs': [os.path.basename(output_dir)],
                              'color': 'rgba(77, 102, 133, 1)', 'color_alpha': 0.2}},
        'tasks': tasks,
        'rolling_mean_count': 10,
        'filter': 'ma',
        'num_cycles': num_cycles,
        'num_cycles_for_forgetting': 1,
        'num_task_steps': num_task_steps,
        'gird_size': [2, 3],
        'which_exp': 'minigrid',
        'xaxis_tickvals': list(np.arange(0, total_steps + 1, 0.5e6)),
    }


def _mean_cl_metrics(forgetting_matrix, transfer_matrix, return_matrix):
    """Collapse the CL matrices into scalar averages.

    With ``n = len(forgetting_matrix)``, the forgetting and transfer sums are
    divided by the number of unordered task pairs, ``n * (n - 1) / 2``. The
    return sum is divided by ``n * n``. An average whose denominator would be
    zero (fewer than two tasks for forgetting/transfer, no tasks for return)
    is reported as NaN instead of dividing by zero.

    Returns:
        tuple[float, float, float]: (mean_forgetting, mean_transfer, mean_return).
    """
    task_number = len(forgetting_matrix)
    num_pairs = task_number * (task_number - 1) / 2
    num_cells = task_number * task_number

    def _safe_mean(matrix, count):
        # NaN signals "undefined" rather than producing inf/nan with a RuntimeWarning.
        return float(np.sum(matrix)) / count if count else float("nan")

    return (_safe_mean(forgetting_matrix, num_pairs),
            _safe_mean(transfer_matrix, num_pairs),
            _safe_mean(return_matrix, num_cells))


def _loggable_hparams(config, excluded=("device", "_output_dir")):
    """Return a copy of ``config``'s attributes without the ``excluded`` keys.

    ``vars()`` returns the object's own ``__dict__``, so a copy is taken before
    removing keys; otherwise the attributes would be deleted from ``config``
    itself. Missing keys are ignored.
    """
    hparams = dict(vars(config))
    for key in excluded:
        hparams.pop(key, None)
    return hparams


def _split_elapsed(elapsed_seconds):
    """Split a non-negative duration in seconds into whole (hours, minutes, seconds)."""
    total_minutes, secs = divmod(int(elapsed_seconds), 60)
    hours, mins = divmod(total_minutes, 60)
    return hours, mins, secs


if __name__ == "__main__":
    start_time = time.time()
    # PyTorch multiprocessing needs the "forkserver" or "spawn" start method.
    # "spawn" is supported on all platforms (Windows included), so no fallback is needed.
    multiprocessing.set_start_method("spawn")

    experiment, policy = ArgparseManager.parse(sys.argv[1:])

    if experiment is None:
        raise RuntimeError("No experiment started. Most likely there is no new run to start.")

    utils.summary_writer = SummaryWriter(log_dir=experiment.output_dir)
    try:
        experiment.try_run(policy, summary_writer=utils.summary_writer)

        # ------------------- Continual-learning (CL) metrics -------------------
        # TODO: determine the task/cycle configuration automatically instead of using the defaults.
        metrics = Metrics(_build_metrics_config(experiment.output_dir, str(policy)))
        forgetting_matrix, transfer_matrix, return_matrix = metrics.compute_single_metrics()
        experiment.logger.info(f"遗忘矩阵：\n{forgetting_matrix}\n\n")
        experiment.logger.info(f"迁移矩阵：\n{transfer_matrix}\n\n")
        experiment.logger.info(f"回报矩阵：\n{return_matrix}\n\n")

        # Average forgetting, transfer and return across tasks.
        mean_forgetting, mean_transfer, mean_return = _mean_cl_metrics(
            forgetting_matrix, transfer_matrix, return_matrix)

        # Record the (filtered) hyperparameters together with the CL metrics.
        utils.summary_writer.add_hparams(
            _loggable_hparams(policy.config),
            {'hparam/mean_forgetting': mean_forgetting, 'hparam/mean_transfer': mean_transfer,
             'hparam/mean_return': mean_return})
    finally:
        # Close the writer so buffered events are flushed to disk even if a step above fails.
        utils.summary_writer.close()

    # Report the total wall-clock time as hours / minutes / seconds.
    hour, mins, secs = _split_elapsed(time.time() - start_time)
    experiment.logger.info(f"--------本次运行用时{hour}小时{mins}分{secs}秒---------")
